import datetime
import math
import os
import traceback
import uuid

import arcpy
import numpy as np
import pandas as pd
import yaml

import ZLog

class n_sedi_50_59_total:
    """Compute an entropy-weighted SEDI value for raster cells with 50 <= gridcode < 60.

    Pipeline (see ``run``): raster -> polygons -> polygons with 50 <= gridcode < 60
    -> 1000 x 1000 (map-unit) fishnet over the raster extent -> per-cell share P of
    the clipped 50-59 area in the whole raster footprint -> pairwise weights
    ``wij`` scaled by ``sqrt(2 * distance)`` -> SEDI. Intermediate feature classes
    are written into a temporary file geodatabase.
    """

    def __init__(self) -> None:
        # ZLog is a project-local logger; the attribute name is kept for existing callers.
        self.loggor = ZLog.ZLog()

    def logDebug(self, msg, c=True):
        """Log ``msg`` at debug level via ``ZLog.cDebug`` when ``c`` is true, else ``ZLog.debug``.

        NOTE(review): the c-prefixed ZLog methods presumably also echo to the console — confirm in ZLog.
        """
        if c:
            self.loggor.cDebug(msg)
        else:
            self.loggor.debug(msg)

    def logInfo(self, msg, c=True):
        """Log ``msg`` at info level via ``ZLog.cInfo`` when ``c`` is true, else ``ZLog.info``."""
        if c:
            self.loggor.cInfo(msg)
        else:
            self.loggor.info(msg)

    def logErr(self, msg, c=True):
        """Log ``msg`` at error level via ``ZLog.cError`` when ``c`` is true, else ``ZLog.error``."""
        if c:
            self.loggor.cError(msg)
        else:
            self.loggor.error(msg)

    def createTempGDB(self, root):
        """Create a uniquely named file geodatabase under ``root`` and return its path.

        ``root`` is created if it does not exist. The returned path has the form
        ``"<root>/<uuid4>.gdb"``.
        """
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(root, exist_ok=True)
        name = str(uuid.uuid4())
        arcpy.CreateFileGDB_management(root, name)
        return root + "/" + name + ".gdb"

    def rasterToPolygon(self, raster_path, out_polygon_path):
        """Convert the raster's ``Value`` field to unsimplified, single-outer-part polygons."""
        arcpy.conversion.RasterToPolygon(in_raster=raster_path, out_polygon_features=out_polygon_path,
                                         simplify='NO_SIMPLIFY', raster_field='Value',
                                         create_multipart_features='SINGLE_OUTER_PART')

    def extractResidentPolygon_50_59(self, input_polygon_path, output_polygon_path):
        """Copy polygons with ``50 <= gridcode < 60`` into ``output_polygon_path``."""
        arcpy.analysis.Select(in_features=input_polygon_path, out_feature_class=output_polygon_path,
                              where_clause='gridcode >= 50 And gridcode < 60')

    def getPolygonArea(self, poly_path):
        """Return the summed ``SHAPE@AREA`` of every feature in ``poly_path`` (0.0 if empty)."""
        # Only the area token is requested; reading SHAPE@ geometries was unused work.
        with arcpy.da.SearchCursor(poly_path, ["SHAPE@AREA"]) as cursor:
            return sum((row[0] for row in cursor), 0.0)

    def createFishnet(self, output_path, raster_path):
        """Create a polygon fishnet of 1000 x 1000 map-unit cells covering the raster's extent."""
        extent = arcpy.Describe(raster_path).extent
        arcpy.management.CreateFishnet(out_feature_class=output_path, origin_coord=str(extent.XMin)+' '+str(extent.YMin),
                                       # A point straight above the origin makes the grid axis-aligned.
                                       y_axis_coord=str(extent.XMin)+' '+str(extent.YMin + 10),
                                       cell_width=1000, cell_height=1000,
                                       corner_coord=str(extent.XMax)+' '+str(extent.YMax),
                                       labels='NO_LABELS', template=extent, geometry_type='POLYGON')

    def filterFishnet(self, fishnet_path, filter_path, output_path):
        """Keep only fishnet cells that spatially match ``filter_path`` (KEEP_COMMON spatial join)."""
        arcpy.analysis.SpatialJoin(target_features=fishnet_path, join_features=filter_path, out_feature_class=output_path, join_operation='JOIN_ONE_TO_ONE', join_type='KEEP_COMMON')

    def pairwiseClip(self, in_features, clip_features, out_feature_class):
        """Clip ``in_features`` to ``clip_features`` with ``arcpy.analysis.PairwiseClip``."""
        arcpy.analysis.PairwiseClip(in_features=in_features, clip_features=clip_features, out_feature_class=out_feature_class)

    def spatialJoin(self, target_features, join_features, out_feature_class):
        """Join clipped areas back onto every fishnet cell (KEEP_ALL, match option CONTAINS).

        The field mapping keeps the cell's own ``Shape_Length``/``Shape_Area`` and adds
        the first contained ``raster_polygon_clip`` feature's values as
        ``Shape_Length_1``/``Shape_Area_1``. It hard-codes the source names
        ``fishnet_1000`` and ``raster_polygon_clip``, which must match the feature
        class names used in ``run``.
        """
        arcpy.analysis.SpatialJoin(target_features=target_features, join_features=join_features, out_feature_class=out_feature_class,
                                   join_operation='JOIN_ONE_TO_ONE', join_type='KEEP_ALL', field_mapping='Shape_Length "Shape_Length" false true true 8 Double 0 0,First,#,fishnet_1000,Shape_Length,-1,-1;Shape_Area "Shape_Area" false true true 8 Double 0 0,First,#,fishnet_1000,Shape_Area,-1,-1;Shape_Length_1 "Shape_Length" false true true 8 Double 0 0,First,#,raster_polygon_clip,Shape_Length,-1,-1;Shape_Area_1 "Shape_Area" false true true 8 Double 0 0,First,#,raster_polygon_clip,Shape_Area,-1,-1', match_option='CONTAINS')

    def featureToPoint(self, in_features, out_feature_class):
        """Convert each feature to a point with ``arcpy.management.FeatureToPoint`` (default location options)."""
        arcpy.management.FeatureToPoint(in_features, out_feature_class)

    def calculateP(self, gdb_path, feature_class_name, total_area):
        """Return ``[[oid, p, x, y], ...]`` for every point in ``gdb_path/feature_class_name``.

        ``p`` is the joined clipped area (field ``Shape_Area_1``, treated as 0.0 when
        null) divided by ``total_area``. ``x`` and ``y`` come from ``SHAPE@XY``.

        :raises ZeroDivisionError: if ``total_area`` is 0.
        """
        pnt_arr = []
        # "/" matches how run() builds every other path inside the geodatabase.
        fc_path = gdb_path + "/" + feature_class_name
        with arcpy.da.SearchCursor(fc_path, ["OID@", "SHAPE@XY", "Shape_Area_1"]) as cursor:
            for object_id, (x, y), joined_area in cursor:
                c_area = joined_area if joined_area is not None else 0.0
                p_val = c_area / total_area
                pnt_arr.append([object_id, p_val, x, y])
                self.logDebug('oid: {:.0f}\tp: {:.28f}\t x: {:.10f}\ty: {:.10f}'.format(object_id, p_val, x, y), False)
        return pnt_arr

    def calculateDistanceMatrix(self, pnt_arr):
        """Return ``[oid_i, oid_j, p_i, p_j, distance]`` for every pair ``i < j`` of ``pnt_arr``.

        Pairs are listed in row-major upper-triangle order. The distance is
        Euclidean on the ``(x, y)`` columns. Fewer than two points yields ``[]``.
        """
        n = len(pnt_arr)
        if n < 2:
            return []
        coordinates = np.asarray(pnt_arr, dtype=float)[:, 2:]
        result_array = []
        for i in range(n - 1):
            oid_source, p_source = pnt_arr[i][0], pnt_arr[i][1]
            # One row of the upper triangle at a time: O(N) temporaries instead of N x N x 2.
            distances = np.sqrt(((coordinates[i + 1:] - coordinates[i]) ** 2).sum(axis=1))
            for j, distance in enumerate(distances, start=i + 1):
                result_array.append([oid_source, pnt_arr[j][0], p_source, pnt_arr[j][1], distance])
        return result_array

    def calculateSqrt2dis(self, distance):
        """Return ``sqrt(2 * distance)``."""
        return math.sqrt(2 * distance)

    def calculateWij(self, a, b):
        """Return ``-(a + b) * H2(a / (a + b))``, where H2 is the binary entropy in bits.

        Returns 0 when either ``a`` or ``b`` is 0.
        """
        if a != 0 and b != 0:
            return -(a+b) *(a/(a+b) *math.log(a/(a+b),2 ) + b /(a+b) *math.log(b/(a+b) ,2 ) )
        return 0

    def _calculateWijArray(self, a, b):
        """Vectorized ``calculateWij`` over broadcastable ``a``/``b``; 0 where either is 0."""
        a, b = np.broadcast_arrays(np.asarray(a, dtype=float), np.asarray(b, dtype=float))
        total = a + b
        valid = (a != 0) & (b != 0)
        result = np.zeros(total.shape)
        s = total[valid]
        ra = a[valid] / s
        rb = b[valid] / s
        # log(x) / log(2) mirrors math.log(x, 2) used by the scalar version.
        log2 = math.log(2.0)
        result[valid] = -s * (ra * (np.log(ra) / log2) + rb * (np.log(rb) / log2))
        return result

    def _sumPairTerms(self, pnt_arr):
        """Return ``(sum_sqrt2dis, sum_wij_multiply_sqrt2dis)`` over all pairs ``i < j``.

        This gives the same sums (up to float rounding) as applying
        ``calculateSqrt2dis``, ``calculateWij`` and ``calculateWij_multiply_sqrt2dis``
        to every row of ``calculateDistanceMatrix(pnt_arr)``. It works one source
        point at a time with NumPy, so memory stays O(N) and there is no per-pair
        Python call.
        """
        n = len(pnt_arr)
        if n < 2:
            return 0.0, 0.0
        data = np.asarray(pnt_arr, dtype=float)
        p = data[:, 1]
        coordinates = data[:, 2:]
        sqrt2dis_parts = []
        weighted_parts = []
        for i in range(n - 1):
            distance = np.sqrt(((coordinates[i + 1:] - coordinates[i]) ** 2).sum(axis=1))
            sqrt2dis = np.sqrt(2 * distance)
            wij = self._calculateWijArray(p[i], p[i + 1:])
            sqrt2dis_parts.append(sqrt2dis.sum())
            weighted_parts.append((wij * sqrt2dis).sum())
        # fsum keeps the cross-row accumulation exactly rounded.
        return math.fsum(sqrt2dis_parts), math.fsum(weighted_parts)

    def calculateWij_multiply_sqrt2dis(self, wij, sqrt2dis):
        """Return ``wij * sqrt2dis``."""
        return wij * sqrt2dis

    def calculateSedi(self, SUM_wij_multiply_sqrt2dis, SUM_sqrt2dis, COUNT_OBJECTID):
        """Return ``SUM_wij_multiply_sqrt2dis / (2 * SUM_sqrt2dis / COUNT_OBJECTID)``.

        :raises ZeroDivisionError: if ``SUM_sqrt2dis`` or ``COUNT_OBJECTID`` is 0.
        """
        return SUM_wij_multiply_sqrt2dis/(2*SUM_sqrt2dis/COUNT_OBJECTID)

    def saveResult(self, file_path, content):
        """Write ``content`` to ``file_path``, creating the parent folder if needed."""
        folder = os.path.dirname(file_path)
        # A bare file name has no folder part; os.makedirs('') would raise.
        if folder:
            os.makedirs(folder, exist_ok=True)
        with open(file_path, 'w', encoding='utf-8') as file:
            file.write(content)

    def run(self, raster_path):
        """Compute the SEDI value for one raster and save it to ``./result/<timestamp>.txt``.

        All intermediate feature classes go into a fresh geodatabase under ``./temp``.
        That geodatabase is left on disk and its path is logged. Any exception is
        logged with its traceback and not re-raised, so a batch loop over several
        rasters keeps going.

        :param raster_path: input raster. Its cell values become the polygon
            ``gridcode`` field that ``extractResidentPolygon_50_59`` filters on.
        """
        try:
            gdb_path = self.createTempGDB("./temp")
            self.logInfo("tempgdb: " + gdb_path, False)
            self.logInfo("raster data: " + raster_path)
            # The duration comments below are wall-clock timings from an earlier profiling run.
            # 5s195ms
            self.logInfo("sta - raster to polygon...")
            ras_to_poly_path = gdb_path + '/raster_polygon'
            self.rasterToPolygon(raster_path=raster_path, out_polygon_path=ras_to_poly_path)
            # 1s463ms
            resident_poly_path = gdb_path + '/resident_polygon'
            self.extractResidentPolygon_50_59(input_polygon_path=ras_to_poly_path, output_polygon_path=resident_poly_path)
            # 1s474ms
            self.logInfo("sta - create fishnet...")
            fishnet_path = gdb_path + '/fishnet_1000'
            self.createFishnet(output_path=fishnet_path, raster_path=raster_path)
            # 11s688ms
            self.logInfo("sta - filter fishnet...")
            fishnet_filter_path = gdb_path + '/fishnet_1000_filter'
            # Filtered by the whole raster footprint (all gridcodes), so cells without
            # 50-59 polygons are kept and later receive P = 0.
            self.filterFishnet(fishnet_path=fishnet_path, filter_path=ras_to_poly_path, output_path=fishnet_filter_path)
            # 2s600ms
            self.logInfo("sta - pairwise clip...")
            resident_poly_clip_path = gdb_path + '/raster_polygon_clip'
            self.pairwiseClip(in_features=fishnet_filter_path, clip_features=resident_poly_path, out_feature_class=resident_poly_clip_path)
            # 2s975ms
            self.logInfo("sta - spatial join...")
            fishnet_join_path = gdb_path + '/fishnet_1000_join'
            self.spatialJoin(target_features=fishnet_filter_path, join_features=resident_poly_clip_path, out_feature_class=fishnet_join_path)
            # 1s443ms
            self.logInfo("sta - feature to point...")
            data_pnt_path = gdb_path + '/data_pnt_join'
            self.featureToPoint(in_features=fishnet_join_path, out_feature_class=data_pnt_path)
            # 1s65ms
            self.logInfo("sta - get workspace area...")
            sum_area = self.getPolygonArea(ras_to_poly_path)
            if sum_area <= 0:
                raise ValueError('workspace polygon area is {0}; cannot compute P'.format(sum_area))
            # 148ms
            self.logInfo("sta - calculate P...")
            pnt_arr = self.calculateP(gdb_path, 'data_pnt_join', total_area=sum_area)
            count_objectid = len(pnt_arr)
            if count_objectid < 2:
                raise ValueError('need at least 2 fishnet points to compute SEDI, got {0}'.format(count_objectid))
            # The pair terms used to be built as a pandas DataFrame filled by row-wise
            # apply() calls (profiled at ~2m + 1.5m + 43s + 2h + 25m). _sumPairTerms
            # produces the same sums with vectorized NumPy, one source point at a time.
            self.logInfo("sta - calculate pairwise sqrt2dis and wij...")
            sum_sqrt2dis, sum_wij_multiply_sqrt2dis = self._sumPairTerms(pnt_arr)
            self.logInfo('fin - calculate result...')
            sedi = self.calculateSedi(sum_wij_multiply_sqrt2dis, sum_sqrt2dis, count_objectid)
            result = 'sum_sqrt2dis: {:.12f}\nsum_wij_multiply_sqrt2dis: {:.12f}\ncount_objectid: {:.0f}\nsedi: {:.12f}'.format(
                sum_sqrt2dis, sum_wij_multiply_sqrt2dis, count_objectid, sedi
            )
            timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")
            self.saveResult(file_path='./result/{0}.txt'.format(timestamp), content=result)
            self.logInfo('fin - save calculate result ./result/{0}.txt\n'.format(timestamp) + result)
        except Exception as ex:
            # Top-level boundary for one raster: log with the traceback and let the batch continue.
            self.logErr('{0}\n{1}'.format(ex, traceback.format_exc()))

# Input rasters, one per year; each is processed independently by main().
raster_path_arr = [
    r'D:\file\project\linYi\file\beijing1980',
    r'D:\file\project\linYi\file\beijing1990',
    r'D:\file\project\linYi\file\beijing1995',
    r'D:\file\project\linYi\file\beijing2000',
    r'D:\file\project\linYi\file\beijing2005',
    r'D:\file\project\linYi\file\beijing2010',
    r'D:\file\project\linYi\file\beijing2015',
    r'D:\file\project\linYi\file\beijing2020',
]


def main():
    """Run the SEDI computation once per raster in ``raster_path_arr``.

    A fresh ``n_sedi_50_59_total`` instance is created for every raster, as before.
    ``run`` logs its own failures, so one bad raster does not stop the batch.
    """
    for rp in raster_path_arr:
        n_sedi_50_59_total().run(rp)


# Guarded so that importing this module no longer launches hours of geoprocessing.
if __name__ == "__main__":
    main()